(ns proteins
(:require
[scicloj.tempfiles.api :as tempfiles]
[tablecloth.api :as tc]
[fastmath.core :as math]
[fastmath.random :as random]
[tech.v3.datatype :as dtype]
[tech.v3.dataset :as dataset]
[tech.v3.dataset.tensor :as dataset.tensor]
[tech.v3.tensor :as tensor]
[tech.v3.datatype.functional :as fun]
[aerial.hanami.common :as hc]
[aerial.hanami.templates :as ht]
[scicloj.kindly.v3.kind :as kind]
[scicloj.kindly.v3.api :as kindly]
[scicloj.clay.v2.api :as clay]
[libpython-clj2.python :refer [py. py.. py.-] :as py]
[scicloj.noj.v1.vis :as vis]
[scicloj.noj.v1.vis.python :as vis.python]
[libpython-clj2.require :refer [require-python]]))
…
Based on the PyMC overview notebook: https://www.pymc.io/projects/docs/en/stable/learn/core_notebooks/pymc_overview.html
(require-python '[builtins :as python]
                'operator
                '[arviz :as az]
                '[arviz.style :as az.style]
                '[pandas :as pd]
                '[matplotlib.pyplot :as plt]
                '[numpy :as np]
                '[numpy.random :as np.random]
                '[pymc :as pm]
                '[Bio.PDB.PDBParser]
                '[Bio.PDB]
                '[Bio.PDB.Polypeptide]
                '[pytensor]
                '[pytensor.tensor :as pt]
                '[math])
:ok
(defn brackets
  "Python-style subscript: returns `obj[entry]` by invoking the object's
  `__getitem__` method with `entry`."
  [obj entry]
  (py.. obj (__getitem__ entry)))
…
(def colon
  "Python's bare slice `:` (slice(None, None), i.e. slice(None, None, None)).
  It can be used as an index, e.g. `(brackets obj colon)`, to select
  everything along an axis."
  (python/slice nil nil))
…
(arviz.style/use "arviz-darkgrid")
nil

(defn- standard-amino-acid?
  "True when Bio.PDB.Polypeptide/is_aa (called with :standard true) accepts
  the residue's name as a standard amino acid."
  [residue]
  (-> residue
      (py. get_resname)
      (Bio.PDB.Polypeptide/is_aa :standard true)))

(defn- residue-ca-coordinates
  "Coordinates of the residue's \"CA\" atom, converted to a float32 array.
  NOTE(review): presumably raises if a residue has no \"CA\" atom — confirm."
  [residue]
  (-> residue
      (brackets "CA")
      (py. get_coord)
      (->> (dtype/->array :float32))))

(defn- model-ca-coordinates
  "Lazy seq of the CA coordinates of every standard amino-acid residue in a
  Biopython model, visiting its chains in order."
  [model]
  (mapcat (fn [chain]
            (->> chain
                 (filter standard-amino-acid?)
                 (map residue-ca-coordinates)))
          model))

(defn extract-coordinates-from-pdb
  "Parses `data/<protein-name>.pdb` with Biopython's PDBParser and extracts
  the alpha-carbon (\"CA\") coordinates of its standard amino-acid residues.

  `data-type` defaults to `:models`. The result is a float32 tensor with one
  entry per model in the structure, each holding that model's CA coordinates
  (presumably [n-residues 3] per model — confirm for your files).

  Throws ex-info for any other `data-type`."
  ([protein-name]
   (extract-coordinates-from-pdb protein-name :models))
  ([protein-name data-type]
   (let [filepath (str "data/" protein-name ".pdb")
         parser (Bio.PDB/PDBParser)
         structure (py. parser get_structure protein-name filepath)]
     (case data-type
       :models (-> (map model-ca-coordinates structure)
                   (tensor/->tensor :datatype :float32))
       (throw (ex-info (str "Unsupported data-type: " data-type)
                       {:data-type data-type
                        :supported #{:models}}))))))
…
(comment
  ;; Evaluate by hand: CA-coordinate tensors for the two proteins.
  (extract-coordinates-from-pdb "1d3z")
  (extract-coordinates-from-pdb "1ubq"))
…
(defn center-1d
  "Subtracts the mean of `xs` from each of its elements, so the result has
  mean zero."
  [xs]
  (let [mu (fun/mean xs)]
    (fun/- xs mu)))
…
(defn center-columns
  "Applies `center-1d` along axis 0 of the tensor `xyzs`. For a 2-D tensor
  whose rows are points (presumably atoms) and whose columns are
  coordinates, this gives every column mean zero."
  [xyzs]
  (tensor/map-axis xyzs center-1d 0))
…
(comment
  ;; Quick check of center-columns on a small 2x3 tensor.
  (center-columns (tensor/->tensor [[1 2 3]
                                    [4 5 9]])))
…
;; `center-columns` is defined earlier in this namespace. An identical second
;; definition that re-def'd the same var here was removed as redundant.
…
(defn read-data
  "Loads one model per protein and centers its coordinates.

  `prots` - protein names; each is read via `extract-coordinates-from-pdb`.

  Options (all optional):
  `:data-type` - only `:models` is supported (default `:models`).
  `:models`    - model index to take for each protein, paired positionally
                 with `prots` (default `[0 1]`).
  `:rmsd?`     - accepted (default `true`) but not used by this function yet.

  Returns a map:
  `:coords`       - the selected coordinate tensors, one per protein,
  `:obs`          - the same tensors after `center-columns`,
  `:obs-datasets` - `:obs` as datasets with columns renamed to `:x :y :z`.

  Throws ex-info for an unsupported `:data-type`."
  ([prots]
   (read-data prots {}))
  ([prots
    {:keys [data-type models rmsd?]
     ;; the :or key must match the bound symbol (`rmsd?`, not `rmsd`),
     ;; otherwise the default is silently ignored
     :or {data-type :models
          models [0 1]
          rmsd? true}}]
   (case data-type
     :models (let [coords (mapv (fn [prot model]
                                  (-> prot
                                      extract-coordinates-from-pdb
                                      (nth model)))
                                prots
                                models)
                   obs (mapv center-columns coords)
                   obs-datasets (mapv #(-> %
                                           dataset.tensor/tensor->dataset
                                           (tc/rename-columns [:x :y :z]))
                                      obs)]
               {:coords coords
                :obs obs
                :obs-datasets obs-datasets})
     (throw (ex-info (str "Unsupported data-type: " data-type)
                     {:data-type data-type
                      :supported #{:models}})))))
…
;; Column summaries (tc/info) of the centered observation datasets for
;; model 4 of 1d3z and model 0 of 1ubq.
(let [name1 "1d3z"
      name2 "1ubq"
      models [4 0]]
  (->> (read-data [name1 name2]
                  {:models models})
       :obs-datasets
       (map tc/info)))
…
;; 3-D lines+markers plot (plotly scatter3d) of the centered coordinates of
;; model 4 of 1d3z and model 0 of 1ubq.
(let [name1 "1d3z"
      name2 "1ubq"
      models [4 0]
      {:keys [obs-datasets]} (read-data [name1 name2]
                                        {:models models})]
  (kind/hiccup
   ['(fn [{:keys [datasets]}]
       [plotly
        {:data (->> datasets
                    (mapv (fn [dataset]
                            (->> dataset
                                 (merge {:type :scatter3d
                                         :mode :lines+markers
                                         :opacity 0.6
                                         :line {:width 10}
                                         :marker {:size 4}})))))}])
    ;; turn each dataset into a map of column name -> plain vector
    {:datasets (->> obs-datasets
                    (mapv #(update-vals % vec)))}]))
…
(defn ->max-distance-to-origin
  "Largest Euclidean norm among the rows of `centered-structure`.
  It squares every entry, sums over axis 1, takes square roots, and returns
  the maximum. For a centered points-by-coordinates tensor this is the
  distance of the farthest point from the origin."
  [centered-structure]
  (let [squared-norms (tensor/reduce-axis (fun/sq centered-structure) fun/sum 1)]
    (fun/reduce-max (fun/sqrt squared-norms))))
…
(defn ->average-structure
  "Element-wise mean of the tensors in `centered-structures`: their sum
  divided by how many there are."
  [centered-structures]
  (let [total (apply fun/+ centered-structures)
        n (count centered-structures)]
    (fun// total n)))
…
Trying out PyTensor, following https://www.pymc.io/projects/docs/en/stable/learn/core_notebooks/pymc_pytensor.html
(let [x (pt/scalar :name "x")
      y (pt/scalar :name "y")
      z (operator/add x y)
      w (pt/mul z 2)
      f (pytensor/function :inputs [x y]
                           :outputs w)]
  (f :x 10
     :y 5))
30.0

(defn- uniform->unit-quaternion
  "Maps a pytensor vector `u` of three Uniform(0, 1) variables to the
  components {:w :x :y :z} of a unit quaternion:
    x, y = sqrt(1 - u0) * (sin, cos)(2*pi*u1)
    w, z = sqrt(u0)     * (cos, sin)(2*pi*u2)
  Hence w^2 + x^2 + y^2 + z^2 = 1. This is Shoemake's construction of a
  uniformly random rotation."
  [u]
  (let [angle (fn [i]
                (-> u
                    (brackets i)
                    (operator/mul 2)
                    (operator/mul math/PI)))
        theta1 (angle 1)
        theta2 (angle 2)
        u0 (brackets u 0)
        r1 (pt/sqrt (operator/sub 1 u0))
        r2 (pt/sqrt u0)]
    {:w (operator/mul (pt/cos theta2) r2)
     :x (operator/mul (pt/sin theta1) r1)
     :y (operator/mul (pt/cos theta1) r1)
     :z (operator/mul (pt/sin theta2) r2)}))

(defn- quaternion->rotation-matrix
  "Builds the 3x3 rotation matrix of the unit quaternion {:w :x :y :z} as a
  pytensor expression: three stacked rows of three scalar entries."
  [{:keys [w x y z]}]
  (let [sq pt/sqr
        twice (fn [t] (operator/mul 2 t))
        r00 (operator/sub (operator/add (sq w) (sq x))
                          (operator/add (sq y) (sq z)))
        r11 (operator/sub (operator/add (sq w) (sq y))
                          (operator/add (sq x) (sq z)))
        r22 (operator/sub (operator/add (sq w) (sq z))
                          (operator/add (sq x) (sq y)))
        r01 (twice (operator/sub (operator/mul x y) (operator/mul w z)))
        r02 (twice (operator/add (operator/mul x z) (operator/mul w y)))
        r10 (twice (operator/add (operator/mul x y) (operator/mul w z)))
        r12 (twice (operator/sub (operator/mul y z) (operator/mul w x)))
        r20 (twice (operator/sub (operator/mul x z) (operator/mul w y)))
        r21 (twice (operator/add (operator/mul y z) (operator/mul w x)))]
    (pt/stack [(pt/stack [r00 r01 r02])
               (pt/stack [r10 r11 r12])
               (pt/stack [r20 r21 r22])])))

(def results
  "Prior-predictive samples of a superposition model for model 4 of 1d3z
  (whose centered coordinates only supply the shape of M). The value is a
  map {:prior-predictive-samples <pm/sample_prior_predictive result>}."
  (let [name1 "1d3z"
        name2 "1ubq"
        models [4 0]
        {:keys [obs]} (read-data [name1 name2]
                                 {:models models})
        shape (dtype/shape (obs 0))
        [n-atoms n-dims] shape]
    (py/with [model (pm/Model)]
      (let [M (pm/Normal "M" :shape shape)
            ;; Subtract the per-column mean (the centroid, axis 0) so M0 is
            ;; centered column by column, like the observed structures from
            ;; `read-data`. A bare (pt/mean M) would subtract one scalar mean
            ;; over all coordinates instead.
            M0 (pm/Deterministic "M0"
                                 (operator/sub M (pt/mean M :axis 0)))
            t1 (pm/Normal "t1" :shape [n-dims])
            t2 (pm/Normal "t2" :shape [n-dims])
            ;; three uniforms parameterize a random unit quaternion
            u (pm/Uniform "u0" :shape [n-dims])
            R (pm/Deterministic "R"
                                (-> u
                                    uniform->unit-quaternion
                                    quaternion->rotation-matrix))
            ;; t2, U, Q1 and Q2 are registered in the model but are not yet
            ;; referenced by any other expression below.
            U (pm/HalfNormal "U"
                             :sigma 0.01
                             :shape n-atoms)
            Q1 (pm/Normal "Q1" :shape shape)
            Q2 (pm/Normal "Q2" :shape shape)
            debug (pm/Deterministic "debug"
                                    (-> M0
                                        (pt/dot R)
                                        (pt/add t1)))
            prior-predictive-samples (pm/sample_prior_predictive)]
        {:prior-predictive-samples prior-predictive-samples}))))
…
;; Mean over every prior draw and coordinate of M0 (M with its mean subtracted).
(np/mean
 (py.- (py.- (:prior-predictive-samples results) prior)
       "M0"))<xarray.DataArray 'M0' ()>
array(-1.2465662e-18)(let [prior-group (py.- (:prior-predictive-samples results) prior)]
  ;; prior draws of `debug`, i.e. M0 · R + t1
  (py.- prior-group "debug"))<xarray.DataArray 'debug' (chain: 1, draw: 500, debug_dim_0: 76, debug_dim_1: 3)>
array([[[[ 0.16706071, 2.6997921 , 3.80579563],
[-0.66583907, 0.11061794, 2.35875333],
[-0.33844605, 3.2677548 , 4.74898799],
...,
[-1.05468472, 1.20188624, 2.56099598],
[-2.39869129, 2.35258191, 0.60791435],
[-2.87583178, 0.60094867, 2.95815657]],
[[ 1.2683437 , 0.08588472, -1.20021449],
[-0.15936383, 0.32123272, 0.1536795 ],
[-0.17914647, -0.94234472, -1.39809149],
...,
[ 0.3398098 , -0.72257216, -0.05470593],
[ 0.93696573, -0.19657193, 0.44033832],
[-1.15839049, -2.90982475, -0.5328221 ]],
[[-4.51618532, 1.59710832, -0.17286661],
[-3.50154325, 0.79761203, -1.26850329],
[-1.18601627, -0.30571579, -1.08569199],
...,
...
...,
[-0.77592396, 2.29038249, -1.87501885],
[-1.69725733, -0.56947283, -1.97132729],
[-0.13016795, 1.55273405, -1.51042059]],
[[ 0.40471076, -0.97868234, -0.82151834],
[-0.16058526, 1.07207209, -0.25623372],
[-0.51692006, 1.04154915, -1.78367788],
...,
[ 0.74325271, -1.1145294 , 0.32821188],
[ 1.4166422 , 0.70435817, -0.76661303],
[ 0.68884763, -0.5847596 , -1.04314515]],
[[-0.27730378, 2.17870692, 0.53384238],
[ 0.50581583, 0.94524572, -0.2116013 ],
[-0.46537371, 1.8002356 , -0.72092065],
...,
[-1.53258094, 0.16178839, -1.10599859],
[-0.01769912, 0.63115223, -0.5830774 ],
[-0.32403442, 0.87416766, 1.37457161]]]])
Coordinates:
* chain (chain) int64 0
* draw (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* debug_dim_0 (debug_dim_0) int64 0 1 2 3 4 5 6 7 ... 68 69 70 71 72 73 74 75
* debug_dim_1 (debug_dim_1) int64 0 1 2